### Fit a model where we calculate the contest for all participants but only include
### original dispute participants in the bargaining stage

date()
library("tidyverse")
library("assertr")
library("foreach")
source("backend_joiners.r")
sessionInfo()


## Bring the saved analysis data and the reference fits into the workspace.
## (Presumably these files provide data_dispute, data_participant,
## confirm_fit_base and confirm_fit_joiners, which are used below but never
## created in this script -- TODO confirm.)
for (rda_file in c("kr_analysis_dispute.rda",
                   "kr_analysis_participant.rda",
                   "confirm_fit_base.rda",
                   "confirm_fit_joiners.rda")) {
  load(rda_file)
}

## Raw MID data; both "-9" and "NA" are read as missing values.
## NOTE(review): the column type is declared for `dispnum4`, while the code
## further down selects `dispnum` -- confirm which column the file contains.
mid_col_types <- cols(dispnum4 = col_integer())
raw_mid_b <- read_csv("gml-midb-2.1.1.csv",
                      col_types = mid_col_types,
                      na = c("-9", "NA"))


## When TRUE, the saved joiners estimates are used as start values and the
## refit is checked against them afterwards.  Set to FALSE to estimate from
## scratch (takes much longer).
CONFIRM_MODE <- TRUE

## Extract originator variable from raw MID data: one row per dispute-state
## pair (id, ccode) with its originator code.
## distinct() is used instead of summarise(orig = unique(orig)): a pair with
## conflicting originator codes then simply keeps all of its distinct rows and
## is caught by the verify() check, rather than depending on summarise()
## returning several rows per group (an error in old dplyr releases and
## deprecated since dplyr 1.1.0).  arrange() keeps the rows ordered by the
## keys, as the grouped summary did.
data_orig <- raw_mid_b %>%
  select(id = dispnum, ccode, orig) %>%
  distinct() %>%
  arrange(id, ccode) %>%
  verify(!duplicated(paste(id, ccode)))

## Attach the originator code to every data set in the participant list
data_participant <- foreach (participants = data_participant) %do% {
  with_orig <- left_join(participants, data_orig, by = c("id", "ccode"))
  ## The join must leave exactly one row per dispute-state pair
  verify(with_orig, !duplicated(paste(id, ccode)))
}

## Every side of every dispute must contain at least one originator
for (participants in data_participant) {
  side_counts <- participants %>%
    group_by(id, sidea) %>%
    summarise(n_orig = sum(orig))
  ## sum() propagates NA, so a missing originator code anywhere within a
  ## dispute side trips the first check; the second requires one originator
  side_counts %>%
    verify(!is.na(n_orig)) %>%
    verify(n_orig >= 1)
}

## Recompute the side-level dispute variables from originators alone
data_dispute <- foreach (i = seq_along(data_dispute)) %do% {
  ## Per-dispute summaries for side A (sidea == 1) and side B (sidea == 0):
  ## mean polity2, maximum majpow, and number of states
  originators <- filter(data_participant[[i]], orig == 1)
  side_vars <- originators %>%
    group_by(id) %>%
    summarise(polity_a = mean(polity2[sidea == 1]),
              polity_b = mean(polity2[sidea == 0]),
              majpow_a = max(majpow[sidea == 1]),
              majpow_b = max(majpow[sidea == 0]),
              n_states_a = sum(sidea == 1),
              n_states_b = sum(sidea == 0))

  ## Discard the existing versions of these columns, then attach the new ones
  data_dispute[[i]] %>%
    select(-c(polity_a, polity_b, majpow_a, majpow_b,
              n_states_a, n_states_b)) %>%
    left_join(side_vars, by = "id")
}

## Model formulas
## Both are multi-part formulas: the left-hand side lists several variables
## joined by `+`, and `|` separates the parts of the right-hand side.  How each
## part is used is decided by est_structwar_ac(), which is not defined in this
## file.
##
## Dispute level.  LHS: id, war, win_a, win_b_alt.  The first two RHS parts
## mirror each other for side A and side B (polity_*, majpow_* plus the shared
## terms log1p(py_alt), s_cinc, contig); the last two hold log(n_states_a) and
## log(n_states_b).
f_dispute <- id + war + win_a + win_b_alt ~
  polity_a + majpow_a + log1p(py_alt) + s_cinc + contig |
  polity_b + majpow_b + log1p(py_alt) + s_cinc + contig |
  log(n_states_a) | log(n_states_b)
## Participant level.  LHS: id, sidea, orig (the originator code merged into
## the participant data above).  The RHS has two parts.
f_participant <- id + sidea + orig ~
  log(gdp_pwt) + log1p(irst) + log1p(pec) + log1p(tpop) + log1p(upop) +
  log1p(distance) + nuclear | log1p(pct_imports) + polity2

## Start values: in confirm mode reuse the saved joiners estimates; otherwise
## take, from each base fit, only the elements whose names begin with "beta"
## or "gamma"
init <- if (CONFIRM_MODE) {
  confirm_fit_joiners
} else {
  map(confirm_fit_base, function(base_coefs) {
    keep <- str_detect(names(base_coefs), "^(beta|gamma)")
    base_coefs[keep]
  })
}

## Fitting process
## est_structwar_ac() is not defined in this file (presumably it comes from
## backend_joiners.r, sourced at the top -- TODO confirm).  It receives both
## formulas, both lists of data sets and the start values chosen above; the
## remaining arguments (n_halton, scale, reltol, iterlim, printLevel) are
## passed through unchanged and their meaning is defined by that function.
fit_joiners <- est_structwar_ac(f_dispute = f_dispute,
                                f_participant = f_participant,
                                data_dispute = data_dispute,
                                data_participant = data_participant,
                                n_halton = 1024,
                                init = init,
                                scale = TRUE,
                                reltol = 1e-14,
                                iterlim = 5000,
                                printLevel = 1)

## Confirm that coefficients match
## Only in confirm mode: for each saved reference vector, the coefficients of
## the corresponding element's `$fit` must agree with it under all.equal()'s
## default tolerance.  On a mismatch all.equal() returns a description instead
## of TRUE, so stopifnot() aborts the script.
if (CONFIRM_MODE) {
  for (i in seq_along(confirm_fit_joiners)) {
    stopifnot(all.equal(coef(fit_joiners[[i]]$fit), confirm_fit_joiners[[i]]))
  }
}

## Seeds generated 2022-07-28 by random.org
seeds <- c(760686, 305122, 963826, 706100, 90654,
           150657, 968851, 15716, 658576, 762687)

## For every fitted model, draw 25 bootstrap resamples of dispute ids (with
## replacement, as many ids as the model's dispute data has rows).  The RNG is
## re-seeded with that model's own seed before its draws, and the resamples
## are stacked one per row.
boot_index_joiners <- lapply(seq_along(fit_joiners), function(i) {
  set.seed(seeds[i])
  dispute_ids <- fit_joiners[[i]]$data_dispute$id
  n_disputes <- nrow(fit_joiners[[i]]$data_dispute)
  foreach (b = 1:25, .combine = "rbind") %do% {
    sample(dispute_ids, size = n_disputes, replace = TRUE)
  }
})

## Make sure the output directory exists, then store the fitted models and
## the bootstrap indices in separate .rda files
if (!dir.exists("results")) {
  dir.create("results")
}
save(list = "fit_joiners", file = "results/fit_joiners.rda")
save(list = "boot_index_joiners", file = "results/boot_index_joiners.rda")


date()
